package face5wap;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

/**
 * Trains a small convolutional GAN on MNIST with DeepLearning4j and periodically shows generated samples.
 *
 * <p>Three computation graphs are kept in lock-step:
 * <ul>
 *   <li>{@code dis}: the trainable discriminator (single sigmoid unit, binary cross-entropy);</li>
 *   <li>{@code gen}: a copy of the generator whose updaters use learning rate 0, used only to produce
 *       fake samples and visualisations;</li>
 *   <li>{@code gan}: the generator stacked on a copy of the discriminator whose updaters use learning
 *       rate 0. This is the graph that actually trains the generator.</li>
 * </ul>
 * The graphs are built independently, so after every optimisation step the trained weights are copied
 * across ({@code dis -> gan} after the discriminator step, {@code gan -> gen} after the generator step).
 */
public class dl4jGANComputerVision {
    private static final Logger log = LoggerFactory.getLogger(dl4jGANComputerVision.class);

    private static final int batchSizePerWorker = 200;
    private static final int numClasses = 10; // Softmax classes of the computer-vision head "cv".
    private static final int numClassesDis = 1; // Single sigmoid unit of the discriminator.
    private static final int numFeatures = 784;
    private static final int numIterations = 10000;
    // Side length of the visualised latent grid: numGenSamples * numGenSamples samples are drawn.
    private static final int numGenSamples = 10;
    private static final int numberOfTheBeast = 666; // RNG seed shared by all graphs.
    private static final int printEvery = 2;
    private static final int tensorDimOneSize = 28;
    private static final int tensorDimTwoSize = 28;
    private static final int tensorDimThreeSize = 1;
    private static final int zSize = 2;
    // Upper bound of the one-sided label noise: real labels fall in (1 - labelNoise, 1],
    // fake labels in [0, labelNoise).
    private static final double labelNoise = 0.1;

    private static final double dis_learning_rate = 0.002;
    private static final double frozen_learning_rate = 0.0;
    private static final double gen_learning_rate = 0.004;

    // NOTE(review): on Windows "E:face/..." is relative to the current directory of drive E:
    // (there is no separator after the colon); "E:/face/model/" may have been intended.
    private static final String resPath = "E:face/model/";
    private static final String dataSetName = "mnist";

    // Parameter keys copied between graphs. "log10stdev" is the key this code uses for the
    // BatchNormalization variance statistic; confirm it still matches if DL4J is upgraded.
    private static final String[] batchNormParamKeys = {"gamma", "beta", "mean", "log10stdev"};
    private static final String[] weightParamKeys = {"W", "b"};

    public static void main(String[] args) throws Exception {
        gan();
    }

    /**
     * Builds all networks and runs the alternating GAN training loop on MNIST.
     * Every {@code printEvery} batches it shows samples generated from a fixed latent grid.
     * At the end it saves the discriminator, the GAN and the generator under {@code resPath}.
     *
     * @throws Exception if MNIST cannot be loaded or a model cannot be written
     */
    private static void gan() throws Exception {
        // Print which ND4J backend (CPU or GPU) is used for training.
        System.out.println(Nd4j.getBackend());
        Nd4j.getMemoryManager().setAutoGcWindow(5000);

        // Models are only written once training has finished, so make sure the output directory
        // exists up front instead of losing the whole run to a missing folder.
        File outputDir = new File(resPath);
        if (!outputDir.isDirectory() && !outputDir.mkdirs()) {
            throw new IllegalStateException("Cannot create output directory: " + outputDir.getAbsolutePath());
        }

        log.info("Unfrozen discriminator!");
        ComputationGraph dis = buildDiscriminator();
        System.out.println(dis.summary());
        System.out.println(Arrays.toString(dis.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        log.info("Frozen generator!");
        ComputationGraph gen = buildGenerator();
        System.out.println(gen.summary());
        System.out.println(Arrays.toString(gen.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        log.info("GAN with unfrozen generator and frozen discriminator!");
        ComputationGraph gan = buildGan();
        System.out.println(gan.summary());
        System.out.println(Arrays.toString(gan.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        // Classifier head on top of the discriminator's features. It is built and printed here,
        // but this method neither trains nor saves it.
        ComputationGraph cv = buildComputerVision(dis);
        System.out.println(cv.summary());
        System.out.println(Arrays.toString(cv.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        log.info("Creating some noise!");
        // Fixed latent inputs, so successive visualisations show how the same points evolve.
        INDArray latentGrid = createLatentGrid();

        DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);
        int batchCounter = 0;
        while (iterTrain.hasNext() && batchCounter < numIterations) {
            DataSet trDataSet = iterTrain.next();
            // Size the labels from the actual batch so a short final batch cannot cause a shape mismatch.
            int batchSize = trDataSet.numExamples();

            // Discriminator step: real images should score ~1, generated images ~0.
            dis.fit(new DataSet(trDataSet.getFeatures(), softLabels(batchSize, true)));
            INDArray fake = gen.output(uniformNoise(batchSize))[0];
            dis.fit(new DataSet(fake, softLabels(batchSize, false)));
            syncDiscriminator(dis, gan);

            // Generator step: label generated images as real so the generator learns to fool the
            // (frozen) discriminator inside the GAN graph.
            gan.fit(new DataSet(uniformNoise(batchSizePerWorker), Nd4j.ones(batchSizePerWorker, 1)));
            syncGenerator(gan, gen);

            batchCounter++;
            log.info("Completed Batch {}!", batchCounter);

            if (batchCounter % printEvery == 0) {
                visualize(new INDArray[]{gen.output(latentGrid)[0]});
            }

            // Cycle through the training set again until numIterations batches have been processed.
            if (!iterTrain.hasNext()) {
                iterTrain.reset();
            }
        }

        log.info("Saving models!");
        ModelSerializer.writeModel(dis, new File(resPath + dataSetName + "_dis_model.zip"), true);
        ModelSerializer.writeModel(gan, new File(resPath + dataSetName + "_gan_model.zip"), true);
        ModelSerializer.writeModel(gen, new File(resPath + dataSetName + "_gen_model.zip"), true);
    }

    /** Global configuration shared by the discriminator, generator and GAN graphs. */
    private static NeuralNetConfiguration.Builder baseConfig() {
        return new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // Clip every gradient element to [-1, 1].
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .l2(0.0001)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER);
    }

    /**
     * Creates the RmsProp updater used by every trainable layer.
     * NOTE(review): the RmsProp constructor is (learningRate, rmsDecay, epsilon), so rmsDecay is 1e-8 here.
     * With that value the squared-gradient average keeps almost no history (the default decay is 0.95).
     * Confirm this is intended.
     */
    private static RmsProp rmsProp(double learningRate) {
        return new RmsProp(learningRate, 1e-8, 1e-8);
    }

    /** Builds and initialises the trainable discriminator: 28x28x1 image -> probability of being real. */
    private static ComputationGraph buildDiscriminator() {
        ComputationGraph dis = new ComputationGraph(baseConfig()
                .graphBuilder()
                .addInputs("dis_input_layer_0")
                .setInputTypes(InputType.convolutionalFlat(tensorDimOneSize, tensorDimTwoSize, tensorDimThreeSize))
                .addLayer("dis_batch_layer_1", new BatchNormalization.Builder()
                        .updater(rmsProp(dis_learning_rate))
                        .build(), "dis_input_layer_0")
                .addLayer("dis_conv2d_layer_2", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(rmsProp(dis_learning_rate))
                        .nIn(1)
                        .nOut(64)
                        .build(), "dis_batch_layer_1")
                .addLayer("dis_maxpool_layer_3", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "dis_conv2d_layer_2")
                .addLayer("dis_conv2d_layer_4", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(rmsProp(dis_learning_rate))
                        .nIn(64)
                        .nOut(128)
                        .build(), "dis_maxpool_layer_3")
                .addLayer("dis_maxpool_layer_5", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "dis_conv2d_layer_4")
                .addLayer("dis_dense_layer_6", new DenseLayer.Builder()
                        .updater(rmsProp(dis_learning_rate))
                        .nOut(1024)
                        .build(), "dis_maxpool_layer_5")
                .addLayer("dis_output_layer_7", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(rmsProp(dis_learning_rate))
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "dis_dense_layer_6")
                .setOutputs("dis_output_layer_7")
                .build());
        dis.init();
        return dis;
    }

    /**
     * Builds and initialises the generator copy used for sampling (latent vector -> image).
     * Its updaters use learning rate 0; its weights are refreshed from the GAN graph via syncGenerator.
     */
    private static ComputationGraph buildGenerator() {
        ComputationGraph gen = new ComputationGraph(baseConfig()
                .graphBuilder()
                .addInputs("gen_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gen_batch_1", new BatchNormalization.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .build(), "gen_input_layer_0")
                .addLayer("gen_dense_layer_2", new DenseLayer.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .nOut(1024)
                        .build(), "gen_batch_1")
                .addLayer("gen_dense_layer_3", new DenseLayer.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .nOut(7 * 7 * 128)
                        .build(), "gen_dense_layer_2")
                .addLayer("gen_batch_4", new BatchNormalization.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .build(), "gen_dense_layer_3")
                // Reshape the dense activations into 128 feature maps of 7x7 before upsampling.
                .inputPreProcessor("gen_deconv2d_5", new FeedForwardToCnnPreProcessor(7, 7, 128))
                .addLayer("gen_deconv2d_5", new Upsampling2D.Builder(2)
                        .build(), "gen_batch_4")
                .addLayer("gen_conv2d_6", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .updater(rmsProp(frozen_learning_rate))
                        .nIn(128)
                        .nOut(64)
                        .build(), "gen_deconv2d_5")
                .addLayer("gen_deconv2d_7", new Upsampling2D.Builder(2)
                        .build(), "gen_conv2d_6")
                .addLayer("gen_conv2d_8", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .activation(Activation.SIGMOID)
                        .updater(rmsProp(frozen_learning_rate))
                        .nIn(64)
                        .nOut(1)
                        .build(), "gen_deconv2d_7")
                .setOutputs("gen_conv2d_8")
                .build());
        gen.init();
        return gen;
    }

    /**
     * Builds and initialises the GAN graph: a trainable generator (gen_learning_rate) feeding a copy of
     * the discriminator whose updaters use learning rate 0.
     * The discriminator half is refreshed from dis via syncDiscriminator.
     */
    private static ComputationGraph buildGan() {
        ComputationGraph gan = new ComputationGraph(baseConfig()
                .graphBuilder()
                .addInputs("gan_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                // Generator half (trainable).
                .addLayer("gan_batch_1", new BatchNormalization.Builder()
                        .updater(rmsProp(gen_learning_rate))
                        .build(), "gan_input_layer_0")
                .addLayer("gan_dense_layer_2", new DenseLayer.Builder()
                        .updater(rmsProp(gen_learning_rate))
                        .nOut(1024)
                        .build(), "gan_batch_1")
                .addLayer("gan_dense_layer_3", new DenseLayer.Builder()
                        .updater(rmsProp(gen_learning_rate))
                        .nOut(7 * 7 * 128)
                        .build(), "gan_dense_layer_2")
                .addLayer("gan_batch_4", new BatchNormalization.Builder()
                        .updater(rmsProp(gen_learning_rate))
                        .build(), "gan_dense_layer_3")
                .inputPreProcessor("gan_deconv2d_5", new FeedForwardToCnnPreProcessor(7, 7, 128))
                .addLayer("gan_deconv2d_5", new Upsampling2D.Builder(2)
                        .build(), "gan_batch_4")
                .addLayer("gan_conv2d_6", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .updater(rmsProp(gen_learning_rate))
                        .nIn(128)
                        .nOut(64)
                        .build(), "gan_deconv2d_5")
                .addLayer("gan_deconv2d_7", new Upsampling2D.Builder(2)
                        .build(), "gan_conv2d_6")
                .addLayer("gan_conv2d_8", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .activation(Activation.SIGMOID)
                        .updater(rmsProp(gen_learning_rate))
                        .nIn(64)
                        .nOut(1)
                        .build(), "gan_deconv2d_7")
                // Discriminator half (learning rate 0; weights copied from dis).
                .addLayer("gan_dis_batch_layer_9", new BatchNormalization.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .build(), "gan_conv2d_8")
                .addLayer("gan_dis_conv2d_layer_10", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(rmsProp(frozen_learning_rate))
                        .nIn(1)
                        .nOut(64)
                        .build(), "gan_dis_batch_layer_9")
                .addLayer("gan_dis_maxpool_layer_11", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "gan_dis_conv2d_layer_10")
                .addLayer("gan_dis_conv2d_layer_12", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(rmsProp(frozen_learning_rate))
                        .nIn(64)
                        .nOut(128)
                        .build(), "gan_dis_maxpool_layer_11")
                .addLayer("gan_dis_maxpool_layer_13", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "gan_dis_conv2d_layer_12")
                .addLayer("gan_dis_dense_layer_14", new DenseLayer.Builder()
                        .updater(rmsProp(frozen_learning_rate))
                        .nOut(1024)
                        .build(), "gan_dis_maxpool_layer_13")
                .addLayer("gan_dis_output_layer_15", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(rmsProp(frozen_learning_rate))
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "gan_dis_dense_layer_14")
                .setOutputs("gan_dis_output_layer_15")
                .build());
        gan.init();
        return gan;
    }

    /**
     * Derives a numClasses-way softmax classifier from the discriminator.
     * Layers up to dis_dense_layer_6 are frozen feature extractors. The sigmoid output layer is
     * replaced by batch normalisation plus a softmax output layer.
     */
    private static ComputationGraph buildComputerVision(ComputationGraph dis) {
        return new TransferLearning.GraphBuilder(dis)
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                        .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                        .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(1.0)
                        .activation(Activation.TANH)
                        .l2(0.0001)
                        .weightInit(WeightInit.XAVIER)
                        .updater(rmsProp(dis_learning_rate))
                        .seed(numberOfTheBeast)
                        .build())
                .setFeatureExtractor("dis_dense_layer_6")
                .removeVertexKeepConnections("dis_output_layer_7")
                .addLayer("dis_batch", new BatchNormalization.Builder()
                        .updater(rmsProp(dis_learning_rate))
                        .nIn(1024)
                        .nOut(1024)
                        .build(), "dis_dense_layer_6")
                .addLayer("dis_output_layer_7", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .updater(rmsProp(dis_learning_rate))
                        .nIn(1024)
                        .nOut(numClasses)
                        .activation(Activation.SOFTMAX)
                        .build(), "dis_batch")
                .build();
    }

    /**
     * Builds an evenly spaced numGenSamples x numGenSamples grid over [-1, 1]^2 in latent space.
     * Row {@code i * numGenSamples + j} holds {@code (grid[i], grid[j])}.
     * Each row has exactly two values, so this assumes {@code zSize == 2}.
     */
    private static INDArray createLatentGrid() {
        INDArray grid = Nd4j.linspace(-1L, 1L, numGenSamples);
        Collection<INDArray> z = new ArrayList<>(numGenSamples * numGenSamples);
        for (int i = 0; i < numGenSamples; i++) {
            for (int j = 0; j < numGenSamples; j++) {
                // Linear indexing works whether linspace returns a rank-1 vector or a [1, n] row vector.
                z.add(Nd4j.create(new double[]{grid.getDouble(i), grid.getDouble(j)}));
            }
        }
        return Nd4j.vstack(z);
    }

    /** Samples {@code rows} latent vectors of length zSize uniformly from [-1, 1). */
    private static INDArray uniformNoise(int rows) {
        return Nd4j.rand(rows, zSize).muli(2.0).subi(1.0);
    }

    /**
     * Returns {@code rows x 1} noisy binary labels for discriminator training.
     * Real labels fall in (1 - labelNoise, 1]; fake labels fall in [0, labelNoise).
     * The labels must stay inside [0, 1] because the discriminator uses binary cross-entropy (XENT).
     */
    private static INDArray softLabels(int rows, boolean real) {
        INDArray noise = Nd4j.rand(rows, 1).muli(labelNoise);
        return real ? noise.rsubi(1.0) : noise;
    }

    /** Copies the freshly trained discriminator weights into the frozen discriminator half of the GAN. */
    private static void syncDiscriminator(ComputationGraph dis, ComputationGraph gan) {
        copyParams(dis, "dis_batch_layer_1", gan, "gan_dis_batch_layer_9", batchNormParamKeys);
        copyParams(dis, "dis_conv2d_layer_2", gan, "gan_dis_conv2d_layer_10", weightParamKeys);
        copyParams(dis, "dis_conv2d_layer_4", gan, "gan_dis_conv2d_layer_12", weightParamKeys);
        copyParams(dis, "dis_dense_layer_6", gan, "gan_dis_dense_layer_14", weightParamKeys);
        copyParams(dis, "dis_output_layer_7", gan, "gan_dis_output_layer_15", weightParamKeys);
    }

    /** Copies the freshly trained generator half of the GAN into the sampling generator. */
    private static void syncGenerator(ComputationGraph gan, ComputationGraph gen) {
        copyParams(gan, "gan_batch_1", gen, "gen_batch_1", batchNormParamKeys);
        copyParams(gan, "gan_dense_layer_2", gen, "gen_dense_layer_2", weightParamKeys);
        copyParams(gan, "gan_dense_layer_3", gen, "gen_dense_layer_3", weightParamKeys);
        copyParams(gan, "gan_batch_4", gen, "gen_batch_4", batchNormParamKeys);
        copyParams(gan, "gan_conv2d_6", gen, "gen_conv2d_6", weightParamKeys);
        copyParams(gan, "gan_conv2d_8", gen, "gen_conv2d_8", weightParamKeys);
    }

    /** For each key, sets parameter {@code key} of {@code to.toLayer} from {@code from.fromLayer}. */
    private static void copyParams(ComputationGraph from, String fromLayer,
                                   ComputationGraph to, String toLayer, String... keys) {
        for (String key : keys) {
            to.getLayer(toLayer).setParam(key, from.getLayer(fromLayer).getParam(key));
        }
    }

    /**
     * Mixes real and fake samples with one random weight per example:
     * {@code alpha * real + (1 - alpha) * fake}, where alpha is drawn uniformly from [0, 1) with
     * shape [batch, 1, 1, 1]. Presumably intended for a gradient-penalty discriminator loss; the
     * current training loop does not call it.
     * NOTE(review): alpha is rank 4, while MNIST features from the iterator are [batch, 784]. Confirm
     * that real and fake have compatible shapes before using this.
     *
     * @param batch number of examples (size of dimension 0 of {@code real} and {@code fake})
     * @param real  real samples
     * @param fake  generated samples
     * @return the per-example interpolation
     */
    public static INDArray randomWeightedAverage(int batch, INDArray real, INDArray fake) {
        INDArray alpha = Nd4j.rand(batch, 1, 1, 1);
        return alpha.mul(real).add(alpha.rsub(1.0).mul(fake));
    }

    private static JFrame frame;
    private static JPanel panel;

    /**
     * Shows the given image batches in a window that is created on first use. Each call replaces the
     * previous contents; null or empty batches are skipped.
     * NOTE(review): this runs on the training thread rather than the Swing event dispatch thread.
     */
    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();
            panel.setLayout(new GridLayout(samples.length, 4, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();
        for (INDArray sample : samples) {
            if (sample == null || sample.size(0) == 0) {
                continue;
            }
            panel.add(imageFromINDArray(sample));
        }

        frame.revalidate();
        frame.pack();
    }

    /**
     * Renders every example of an image batch into one grayscale picture, tiled row-major in a square
     * grid (example {@code k} goes to row {@code k / cols}, column {@code k % cols}).
     *
     * @param array non-empty batch of shape [n, channels, height, width]; only channel 0 is drawn.
     *              Values are expected in [0, 1] (the generator ends in a sigmoid) and are clamped.
     * @return a label showing the scaled image grid
     * @throws IllegalArgumentException if {@code array} is not rank 4 or holds no examples
     */
    private static JLabel imageFromINDArray(INDArray array) {
        long[] shape = array.shape();
        if (array.rank() != 4 || shape[0] == 0) {
            throw new IllegalArgumentException("Expected a non-empty [n, c, h, w] batch but got shape "
                    + Arrays.toString(shape));
        }
        int count = (int) shape[0];
        int height = (int) shape[2];
        int width = (int) shape[3];
        int cols = (int) Math.ceil(Math.sqrt(count));
        int rows = (count + cols - 1) / cols;

        BufferedImage image = new BufferedImage(cols * width, rows * height, BufferedImage.TYPE_BYTE_GRAY);
        for (int k = 0; k < count; k++) {
            int offsetX = (k % cols) * width;
            int offsetY = (k / cols) * height;
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    image.getRaster().setSample(offsetX + x, offsetY + y, 0, toGray(array.getDouble(k, 0, y, x)));
                }
            }
        }

        // A single image keeps the 9x zoom; bigger grids are zoomed less so the window stays a sane size.
        int scale = Math.max(2, 9 / cols);
        Image scaled = image.getScaledInstance(image.getWidth() * scale, image.getHeight() * scale,
                Image.SCALE_DEFAULT);
        return new JLabel(new ImageIcon(scaled));
    }

    /** Maps an intensity in [0, 1] to an 8-bit gray level, clamping out-of-range values. */
    private static int toGray(double value) {
        int gray = (int) Math.round(value * 255.0);
        return Math.max(0, Math.min(255, gray));
    }
}